#!/usr/bin/env python
#
# Copyright (C) 2011 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ABOUT
#   This script is used to generate the trace implementations of all
#   OpenGL calls. When executed, it reads the specs for the OpenGL calls
#   from the files GLES2/gl3_api.in, GLES2/gl3ext_api.in, GLES2/gl2_api.in,
#   GLES2/gl2ext_api.in, GLES_CM/gl_api.in, and GLES_CM/glext_api.in
#   (see API_SPECS), and generates trace versions for all the defined
#   functions.
#
# PREREQUISITES
#   To generate C++ files, this script uses the 'pyratemp' template
#   module. The only reason to use pyratemp is that it is extremely
#   simple to install:
#   $ wget http://www.simple-is-better.org/template/pyratemp-current/pyratemp.py
#   Put the file in the GLES_trace/tools folder, or update PYTHONPATH
#   to point to wherever it was downloaded.
#
# USAGE
#   $ cd GLES_trace       - run the program from the GLES_trace folder
#   $ ./tools/genapi.py   - generates gltrace_api.cpp and gltrace_api.h
#   $ mv *.cpp *.h src/   - move the generated files into the src folder

import sys
import re
import pyratemp

# Constants corresponding to the protobuf DataType.Type
class DataType:
    """A category of GL value, named after the protobuf DataType.Type enum.

    str() gives the enum name emitted into generated C++ (pointers are
    reported as INT), and getProtobufCall() gives the C++ setter-call prefix
    the generated code uses to store a value of this category.

    Instances are shared singletons (DataType.VOID, DataType.INT, ...); no
    __eq__ is defined, so comparisons between them are identity-based.
    """

    # C++ setter-call prefix per category name.  The narrow integer-like
    # categories (and pointers) are cast to int before being stored.
    _SETTER_PREFIXES = {
        "char": "add_intvalue((int)",
        "byte": "add_intvalue((int)",
        "pointer": "add_intvalue((int)",
        "enum": "add_intvalue((int)",
        "int": "add_intvalue(",
        "float": "add_floatvalue(",
        "bool": "add_boolvalue(",
        "int64": "add_int64value(",
    }

    def __init__(self, name):
        self.name = name

    def __str__(self):
        # There is no pointer entry in the protobuf enum, so pointers use INT.
        return "INT" if self.name == "pointer" else self.name.upper()

    def getProtobufCall(self):
        """Return the C++ setter-call prefix for a value of this category.

        Raises ValueError for "void" (nothing to store) and for any category
        name without a known setter.
        """
        if self.name == "void":
            raise ValueError("Attempt to set void value")
        prefix = self._SETTER_PREFIXES.get(self.name)
        if prefix is None:
            raise ValueError("Unknown value type %s" % self.name)
        return prefix

DataType.VOID = DataType("void")
DataType.CHAR = DataType("char")
DataType.BYTE = DataType("byte")
DataType.ENUM = DataType("enum")
DataType.BOOL = DataType("bool")
DataType.INT = DataType("int")
DataType.FLOAT = DataType("float")
DataType.POINTER = DataType("pointer")
DataType.INT64 = DataType("int64")

# mapping of GL types to protobuf DataType
# Used by getDataTypeFromKw() for non-pointer types only: any spelling that
# contains "*" is classified as POINTER before this table is consulted.
# Lookups are exact dict lookups, so a spelling missing here (including one
# with extra qualifiers such as "const") yields None.
# All integer-like GL types (short, fixed, sizeiptr, ...) collapse into INT.
GLPROTOBUF_TYPE_MAP = {
    "GLvoid":DataType.VOID,
    "void":DataType.VOID,
    "GLchar":DataType.CHAR,
    "GLenum":DataType.ENUM,
    "GLboolean":DataType.BOOL,
    "GLbitfield":DataType.INT,
    "GLbyte":DataType.BYTE,
    "GLshort":DataType.INT,
    "GLint":DataType.INT,
    "int":DataType.INT,
    "GLsizei":DataType.INT,
    "GLubyte":DataType.BYTE,
    "GLushort":DataType.INT,
    "GLuint":DataType.INT,
    "GLfloat":DataType.FLOAT,
    "GLclampf":DataType.FLOAT,
    "GLfixed":DataType.INT,
    "GLclampx":DataType.INT,
    "GLsizeiptr":DataType.INT,
    "GLintptr":DataType.INT,
    "GLeglImageOES":DataType.POINTER,  # opaque handle, recorded like a pointer
    "GLint64":DataType.INT64,
    "GLuint64":DataType.INT64,
    "GLsync":DataType.POINTER,         # opaque handle, recorded like a pointer
}

# (prefix, spec file) pairs to parse, in priority order.  When the same
# function appears in several specs, removeDuplicates() keeps the entry from
# the spec listed first, so GL3 definitions win over GL2 and GL1 ones.
# The prefix labels the section comments in the generated files.
# Paths are opened as-is, i.e. relative to the current working directory
# (the USAGE notes say to run the script from the GLES_trace folder).
API_SPECS = [
    ('GL3','../GLES2/gl3_api.in'),
    ('GL3Ext','../GLES2/gl3ext_api.in'),
    ('GL2','../GLES2/gl2_api.in'),
    ('GL2Ext','../GLES2/gl2ext_api.in'),
    ('GL1','../GLES_CM/gl_api.in'),
    ('GL1Ext','../GLES_CM/glext_api.in'),
]

# License banner written at the top of both generated files (.h and .cpp).
HEADER_LICENSE = """/*
 * Copyright 2011, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * THIS FILE WAS GENERATED BY A SCRIPT. DO NOT EDIT.
 */
"""

# #include block written only into the generated .cpp file (see genSrcs).
HEADER_INCLUDES = """
#include <cutils/log.h>
#include <utils/Timers.h>
#include <GLES3/gl3.h>

#include "gltrace.pb.h"
#include "gltrace_context.h"
#include "gltrace_fixup.h"
#include "gltrace_transport.h"
"""

# Opens the C++ namespaces that wrap every generated declaration/definition.
HEADER_NAMESPACE_START = """
namespace android {
namespace gltrace {
"""

# Closes the namespaces opened by HEADER_NAMESPACE_START.
FOOTER_TEXT = """
}; // namespace gltrace
}; // namespace android
"""

# pyratemp template for the definition of one GLTrace_<func> wrapper.
# Rendered by ApiCall.genCode(), which supplies: func, retType, retDataType,
# inputArgList, callsite, parsedArgs (list of (name, DataType) tuples) and
# the DataType class itself.
#
# The generated wrapper:
#   1. records every argument in a GLMessage,
#   2. calls the real GL entry point through glContext->hooks, timed with
#      both the SYSTEM_TIME_MONOTONIC and SYSTEM_TIME_THREAD clocks,
#   3. records the return value (non-void functions only),
#   4. passes pointer arguments / a pointer return value to fixupGLMessage()
#      and then traces the message.
#
# The "== DataType.POINTER" tests are identity comparisons (DataType defines
# no __eq__), so parsedArgs/retDataType must hold the shared DataType.*
# instances.
# NOTE(review): for a call with no pointer args and no pointer return, this
# emits "void *pointerArgs[] = { };", a zero-length array, which ISO C++
# rejects; it presumably relies on a compiler extension. Confirm the toolchain.
TRACE_CALL_TEMPLATE = pyratemp.Template(
"""$!retType!$ GLTrace_$!func!$($!inputArgList!$) {
    GLMessage glmsg;
    GLTraceContext *glContext = getGLTraceContext();

    glmsg.set_function(GLMessage::$!func!$);
<!--(if len(parsedArgs) > 0)-->
    <!--(for argname, argtype in parsedArgs)-->

    // copy argument $!argname!$
    GLMessage_DataType *arg_$!argname!$ = glmsg.add_args();
    arg_$!argname!$->set_isarray(false);
    arg_$!argname!$->set_type(GLMessage::DataType::$!argtype!$);
    arg_$!argname!$->$!argtype.getProtobufCall()!$$!argname!$);
    <!--(end)-->
<!--(end)-->

    // call function
    nsecs_t wallStartTime = systemTime(SYSTEM_TIME_MONOTONIC);
    nsecs_t threadStartTime = systemTime(SYSTEM_TIME_THREAD);
<!--(if retType != "void")-->
    $!retType!$ retValue = glContext->hooks->gl.$!callsite!$;
<!--(else)-->
    glContext->hooks->gl.$!callsite!$;
<!--(end)-->
    nsecs_t threadEndTime = systemTime(SYSTEM_TIME_THREAD);
    nsecs_t wallEndTime = systemTime(SYSTEM_TIME_MONOTONIC);
<!--(if retType != "void")-->

    // set return value
    GLMessage_DataType *rt = glmsg.mutable_returnvalue();
    rt->set_isarray(false);
    rt->set_type(GLMessage::DataType::$!retDataType!$);
    rt->$!retDataType.getProtobufCall()!$retValue);
<!--(end)-->

    void *pointerArgs[] = {
<!--(for argname, argtype in parsedArgs)-->
    <!--(if argtype == DataType.POINTER)-->
        (void *) $!argname!$,
    <!--(end)-->
<!--(end)-->
<!--(if retDataType == DataType.POINTER)-->
        (void *) retValue,
<!--(end)-->
    };

    fixupGLMessage(glContext, wallStartTime, wallEndTime,
                              threadStartTime, threadEndTime,
                              &glmsg, pointerArgs);
    glContext->traceGLMessage(&glmsg);
<!--(if retType != "void")-->

    return retValue;
<!--(end)-->
}
""")

def getDataTypeFromKw(kw):
    """ Map a C type spelling to its DataType.

    Any spelling containing '*' is a pointer and maps to DataType.POINTER.
    Otherwise the spelling is looked up verbatim in GLPROTOBUF_TYPE_MAP, and
    None is returned when it is not listed there.

    e.g.: GLvoid -> DataType.VOID, "const GLchar*" -> DataType.POINTER"""
    isPointer = '*' in kw
    return DataType.POINTER if isPointer else GLPROTOBUF_TYPE_MAP.get(kw)

def getNameTypePair(decl):
    """ Split declaration of a variable to a tuple of (variable name, DataType).
    e.g. "const GLchar* varName" -> ("varName", DataType.POINTER)

    Stars attached to the name ("GLint *params") and array suffixes
    ("GLfloat m[16]") both make the type a pointer. A declaration with fewer
    than two tokens (e.g. "void") has no name, so the name is None. The
    DataType is None for a non-pointer type that GLPROTOBUF_TYPE_MAP does
    not list (see getDataTypeFromKw).
    """
    # split() with no separator splits on any whitespace run (spaces, tabs)
    # and drops empty strings. split(' ') produced empty tokens on repeated
    # spaces and failed to separate type and name when a tab was used.
    elements = decl.split()
    if len(elements) < 2:
        # Zero or one token: no variable name, only a (possibly empty) type.
        return (None, getDataTypeFromKw("".join(elements)))

    name = elements[-1]                 # last element is the name
    dataType = " ".join(elements[:-1])  # everything else is the data type

    # if name is a pointer (e.g. "*ptr"), then remove the "*" from the name
    # and add it to the data type
    pointersInName = name.count("*")
    if pointersInName > 0:
        name = name.replace("*", "")
        dataType += "*" * pointersInName

    # if name is an array (e.g. "array[10]"), then remove the "[X]" from the
    # name and make the datatype a pointer
    if "[" in name:
        name = name.split('[')[0]
        dataType += "*"

    return (name, getDataTypeFromKw(dataType))

def parseArgs(arglist):
    """ Parse the argument list into a list of (var name, DataType) tuples.

    arglist is the raw C parameter list taken from an API_ENTRY line, e.g.
    "GLenum target, GLuint texture". An empty list, or one consisting of a
    single void parameter (e.g. "void"), yields [].
    """
    # Without this guard an empty list would parse as one (None, None) entry,
    # which later breaks template rendering.
    if not arglist.strip():
        return []

    # Build a real list (not a lazy map object) so len() and indexing also
    # work on Python 3.
    argtypelist = [getNameTypePair(arg.strip()) for arg in arglist.split(',')]

    # "(void)" declares no parameters at all.
    if len(argtypelist) == 1 and argtypelist[0][1] == DataType.VOID:
        return []

    return argtypelist

class ApiCall(object):
    """An ApiCall models all information about a single OpenGL API.

    It is built from two lines of a *_api.in specification file: the
    API_ENTRY line (return type and C argument list) and the CALL_GL_API
    line (the dispatched function name and the arguments forwarded to it).
    """

    # Regex to match API_ENTRY specification:
    #       e.g. void API_ENTRY(glActiveTexture)(GLenum texture) {
    # the regex uses a non greedy match (?) to match the first closing paren.
    # Raw strings keep "\(" from being treated as an (invalid) string escape
    # on Python 3; the pattern text is unchanged.
    API_ENTRY_REGEX = r"(.*)API_ENTRY\(.*?\)\((.*?)\)"

    # Regex to match CALL_GL_API specification:
    #       e.g. CALL_GL_API(glCullFace, mode);
    #            CALL_GL_API_RETURN(glCreateProgram);
    CALL_GL_API_REGEX = r"CALL_GL_API(_RETURN)?\((.*)\);"

    # Compiled once and shared by all instances.
    _API_ENTRY_RE = re.compile(API_ENTRY_REGEX)
    _CALL_GL_API_RE = re.compile(CALL_GL_API_REGEX)

    def __init__(self, prefix, apientry, callsite):
        """Construct an ApiCall from its specification.

        prefix: label of the API group (e.g. 'GL2'); genHeaders/genSrcs use
              it for the section comments in the generated files
        apientry: specification line containing API_ENTRY macro
              e.g: void API_ENTRY(glActiveTexture)(GLenum texture) {
        callsite: specification line containing CALL_GL_API macro
              e.g: CALL_GL_API(glActiveTexture, texture);

        Raises ValueError if either line does not match its expected format.
        """
        self.prefix = prefix
        self.ret = self.getReturnType(apientry)
        self.arglist = self.getArgList(apientry)

        # The function name is taken from the CALL_GL_API line, not from
        # API_ENTRY: some functions (e.g.
        # __glEGLImageTargetRenderbufferStorageOES) are named one way in the
        # API_ENTRY macro and another way in the CALL_GL_API macro.
        self.func = self.getFunc(callsite)
        self.callsite = self.getCallSite(callsite)

    def _matchApiEntry(self, apientry):
        '''Return the API_ENTRY regex match for apientry, or raise ValueError.'''
        m = self._API_ENTRY_RE.search(apientry)
        if not m:
            raise ValueError("%s does not match API_ENTRY specification %s"
                             % (apientry, self.API_ENTRY_REGEX))
        return m

    def getReturnType(self, apientry):
        '''Extract the return type from the API_ENTRY specification'''
        return self._matchApiEntry(apientry).group(1).strip()

    def getArgList(self, apientry):
        '''Extract the argument list from the API_ENTRY specification'''
        return self._matchApiEntry(apientry).group(2).strip()

    def parseCallSite(self, callsite):
        '''Split the CALL_GL_API macro arguments into a list of strings.

        The first element is the GL function name; the rest are the argument
        expressions it forwards, each stripped of surrounding whitespace.
        Raises ValueError if callsite does not match CALL_GL_API_REGEX.
        '''
        m = self._CALL_GL_API_RE.search(callsite)
        if not m:
            raise ValueError("%s does not match CALL_GL_API specification (%s)"
                             % (callsite, self.CALL_GL_API_REGEX))

        # A real list (not a lazy map object) so callers can index it on
        # Python 3.
        return [arg.strip() for arg in m.group(2).split(',')]

    def getCallSite(self, callsite):
        '''Build the dispatch expression, e.g. "glActiveTexture(texture)".'''
        args = self.parseCallSite(callsite)
        return "%s(%s)" % (args[0], ", ".join(args[1:]))

    def getFunc(self, callsite):
        '''Extract the function name from the CALL_GL_API specification'''
        return self.parseCallSite(callsite)[0]

    def genDeclaration(self):
        '''Return the C++ prototype of this API's GLTrace_ wrapper.'''
        return "%s GLTrace_%s(%s);" % (self.ret, self.func, self.arglist)

    def genCode(self):
        '''Render the C++ definition of this API's GLTrace_ wrapper.'''
        return TRACE_CALL_TEMPLATE(func = self.func,
                                   retType = self.ret,
                                   retDataType = getDataTypeFromKw(self.ret),
                                   inputArgList = self.arglist,
                                   callsite = self.callsite,
                                   parsedArgs = parseArgs(self.arglist),
                                   DataType=DataType)

def getApis(apiEntryFile, prefix):
    '''Get a list of all ApiCalls in provided specification file.

    The file is read as consecutive 3-line records: the API_ENTRY line, the
    CALL_GL_API line, and a third line that is skipped. Trailing lines that
    do not form a complete record are ignored.

    apiEntryFile: path of the spec file to read
    prefix: API group label attached to every ApiCall (e.g. 'GL2')

    Raises IOError/OSError if the file cannot be opened, and ValueError
    (from ApiCall) if a record does not match the expected macros.
    '''
    # Close the file deterministically instead of leaking the handle.
    with open(apiEntryFile) as specFile:
        lines = specFile.readlines()

    # Floor division: "/" would yield a float on Python 3, which range()
    # rejects.
    recordCount = len(lines) // 3
    return [ApiCall(prefix, lines[3 * i], lines[3 * i + 1])
            for i in range(recordCount)]

def parseAllSpecs(specs):
    '''Parse every (prefix, specfile) pair in specs into one list of ApiCalls.

    APIs keep the order of specs, and file order within each spec. A
    one-line summary is printed for every file parsed.
    '''
    apis = []
    for name, specfile in specs:
        a = getApis(specfile, name)
        # Single-argument print(...) behaves the same on Python 2 and 3; the
        # print statement is a SyntaxError on Python 3.
        print('Parsed %s APIs from %s, # of entries = %d' % (name, specfile, len(a)))
        apis.extend(a)
    return apis

def removeDuplicates(apis):
    '''Remove all duplicate function entries.

    The input may define the same function name in several APIs (GL1, GL2,
    GL3, ...). The result keeps input order and retains only the first entry
    seen for each function name (compared via the .func attribute).'''
    seenNames = set()
    kept = []
    for candidate in apis:
        if candidate.func in seenNames:
            continue
        seenNames.add(candidate.func)
        kept.append(candidate)
    return kept

def _writeGeneratedFile(fname, preamble, sectionFormat, apis, render):
    '''Write one generated C++ file (shared by genHeaders and genSrcs).

    fname: output path; any existing file is overwritten
    preamble: list of text chunks written first
    sectionFormat: %-format string turning an API prefix into a section
        comment; emitted whenever an API's prefix differs from the previous one
    apis: ApiCalls to emit, in order
    render: callable returning the text for a single ApiCall

    Each rendered API is followed by a newline, and FOOTER_TEXT ends the file.
    '''
    lines = list(preamble)
    currentPrefix = ""
    for api in apis:
        # Start a new section whenever the API group changes.
        if api.prefix != currentPrefix:
            lines.append(sectionFormat % api.prefix)
            currentPrefix = api.prefix
        lines.append(render(api))
        lines.append("\n")
    lines.append(FOOTER_TEXT)

    with open(fname, "w") as f:
        f.writelines(lines)

def genHeaders(apis, fname):
    '''Write fname as a header declaring a GLTrace_ wrapper for every API.'''
    _writeGeneratedFile(fname,
                        [HEADER_LICENSE, HEADER_NAMESPACE_START],
                        "\n// Declarations for %s APIs\n\n",
                        apis,
                        lambda api: api.genDeclaration())

def genSrcs(apis, fname):
    '''Write fname as a source file defining a GLTrace_ wrapper for every API.'''
    _writeGeneratedFile(fname,
                        [HEADER_LICENSE, HEADER_INCLUDES, HEADER_NAMESPACE_START],
                        "\n// Definitions for %s APIs\n\n",
                        apis,
                        lambda api: api.genCode())

if __name__ == '__main__':
    # Read every spec file, then keep only the first definition of each
    # function (specs listed earlier in API_SPECS take precedence).
    apis = removeDuplicates(parseAllSpecs(API_SPECS))

    # Write the generated declarations and definitions into the current directory.
    genHeaders(apis, 'gltrace_api.h')
    genSrcs(apis, 'gltrace_api.cpp')
